import scipy
import numpy as np
from scipy.integrate import quad
from numpy.lib.scimath import sqrt as csqrt

def mp_pdf(x,ratio):
    """Probability density function of the Marchenko-Pastur (MP) distribution.
    w/o the point mass at 0 and variance of the entries = 1.

    The density is evaluated as

        ratio * sqrt((max_eigval - x) * (x - min_eigval)) / (2 * pi * x)

    for min_eigval < x < max_eigval, and is 0 elsewhere, where
    min_eigval, max_eigval = (1 -/+ sqrt(1 / ratio)) ** 2.

    Parameters
    -----------
        x : float or array-like
            Input on which to evaluate the density function. If array-like,
            understood as a sequence of inputs
        ratio : float
            represents the ratio d/n
    Returns
    ---------
        out : numpy.ndarray, same shape as x
            probability density function for each element in x (a 0-d array
            when x is a scalar). Floating-point inputs keep their dtype;
            integer or boolean inputs are evaluated in float64.
    """
    # Edges of the continuous part of the spectrum.
    root = np.sqrt(1/ratio)
    min_eigval = (np.float64(1) - root)**2
    max_eigval = (np.float64(1) + root)**2

    x = np.asarray(x)
    # zeros_like inherits x's dtype: with integer input the density values
    # would be truncated to 0 on assignment, so promote to float first.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(np.float64)
    out = np.zeros_like(x)

    # Only the open interval (min_eigval, max_eigval) carries density here;
    # the point mass at 0 is deliberately left out.
    idx = (x > min_eigval) & (x < max_eigval)
    inside = x[idx]
    out[idx] = ratio * np.sqrt((max_eigval - inside) * (inside - min_eigval)) / (2 * np.pi * inside)
    return out


def volterra_Psi(max_iter, Delta, gamma, zeta, ratio, R, R_tilde):
    '''
    Computes the values of the Volterra Equation with Marcenko-Pastur
    For details see Section 9 "Numerical Simulations" in the paper.

    The three kernels h0, h1 and H2 are integrals against the
    Marcenko-Pastur density (evaluated with mp_pdf); they are approximated
    with a Chebyshev-Gauss rule on the support of the density, and h0 is
    additionally augmented by the point mass at zero of weight
    max(0, 1 - ratio). The recursion solved for t = 0, ..., max_iter - 1 is

        psi[t] = Re( 0.5 * R * h1[t] + 0.5 * R_tilde * h0[t]
                     + gamma**2 * zeta * (1 - zeta)
                       * sum_{s < t} H2[t - 1 - s] * psi[s] )

    Parameters
    -----------
      Inputs
      -------
      int   max_iter   : number of values computed
      float Delta      : momentum parameter (floored at 1e-8)
      float gamma      : learning_rate or stepsize
      float zeta       : batch-size divided by number of samples (i.e. beta/n)
      float ratio      : number of features divided by number of samples (i.e. d/n)
      float R          : normalization constant for signal (see Problem Setting in paper)
      float R_tilde    : normalization constant for noise (see Problem Setting in paper)

      Output
      ------
      numpy.ndarray psi0_list : float64 array of length max_iter with the
                                function values of volterra with
                                Marcenko-Pastur (empty when max_iter == 0)

    '''
    # Floor Delta for numerical stability. The value is made complex so that
    # the square root of a negative discriminant below stays well defined;
    # double precision is kept so that Delta is not rounded to float32.
    Delta = np.maximum(Delta, 1e-8).astype('complex128')

    # Edges of the continuous part of the Marcenko-Pastur spectrum.
    min_eigval = (np.float64(1) - np.sqrt(1/ratio))**2
    max_eigval = (np.float64(1) + np.sqrt(1/ratio))**2

    # Number of grid points for the integral approximation
    n_gridMP = 4000

    # Chebyshev-Gauss angles (j + 1/2) * pi / n, j = 0, ..., n - 1. Each of
    # the n cells has width pi / n, which is the factor used in the weights;
    # the nodes are mapped from [-1, 1] onto [min_eigval, max_eigval].
    half_width = (max_eigval - min_eigval) / 2.0
    midpoint = (max_eigval + min_eigval) / 2.0
    thetas = (np.arange(n_gridMP) + 0.5) * (np.pi / n_gridMP)
    sigmaSpace = np.cos(thetas) * half_width + midpoint
    quadweights = np.sin(thetas) * half_width * (np.pi / n_gridMP)

    # Density times quadrature weight: sum(f(sigmaSpace) * measure)
    # approximates the integral of f against the continuous MP density.
    measure = mp_pdf(sigmaSpace, ratio) * quadweights

    def spectral_terms(x):
        # Quantities of the integrands that do not depend on the iteration t.
        omega = np.float64(1) - zeta*gamma*x + Delta
        omega_sqd = omega**2
        gap = omega_sqd - 4*Delta
        discriminant = csqrt(omega_sqd * gap)
        lambda_2 = ((omega_sqd - 2.0*Delta) + discriminant)/2.0
        lambda_3 = ((omega_sqd - 2.0*Delta) - discriminant)/2.0
        kappa_2 = (lambda_2*omega)/(lambda_2 + Delta)
        kappa_3 = (lambda_3*omega)/(lambda_3 + Delta)
        return gap, lambda_2, lambda_3, kappa_2, kappa_3

    # On the quadrature grid (continuous part of the spectrum).
    gap, lambda_2, lambda_3, kappa_2, kappa_3 = spectral_terms(sigmaSpace)
    amp_2 = 0.5*(kappa_2 - Delta)**2
    amp_3 = 0.5*(kappa_3 - Delta)**2
    decay_coeff = -Delta * gamma * zeta * sigmaSpace
    # 2 * x**k / (Omega**2 - 4*Delta), already multiplied by the measure.
    pref_0 = 2*measure/gap
    pref_1 = pref_0 * sigmaSpace
    pref_2 = pref_0 * sigmaSpace**2

    # At x = 0 (point mass); the term proportional to x vanishes there.
    gap_0, lambda_2_0, lambda_3_0, kappa_2_0, kappa_3_0 = spectral_terms(0)
    amp_2_0 = 0.5*(kappa_2_0 - Delta)**2
    amp_3_0 = 0.5*(kappa_3_0 - Delta)**2
    mass_at_zero = np.maximum(0.0, 1.0-ratio)

    h0 = np.zeros(max_iter, dtype='complex128')
    h1 = np.zeros(max_iter, dtype='complex128')
    H2 = np.zeros(max_iter, dtype='complex128')

    # populate h0, h1, H2, to be used to compute psi0(t) for t = 0, ..., max_iter - 1
    for t in range(max_iter):
        h_bracket = decay_coeff*(Delta**t) + amp_2*(lambda_2**t) + amp_3*(lambda_3**t)
        h0[t] = np.dot(pref_0, h_bracket)
        h1[t] = np.dot(pref_1, h_bracket)

        H_bracket = -(Delta**(t+1)) + 0.5*(lambda_2**(t+1)) + 0.5*(lambda_3**(t+1))
        H2[t] = np.dot(pref_2, H_bracket)

        # augment h0 by the delta-mass at zero, evaluated at this same t
        atom = (2/gap_0) * (amp_2_0*(lambda_2_0**t) + amp_3_0*(lambda_3_0**t))
        h0[t] += atom * mass_at_zero

    # Solve the discrete Volterra recursion forward in t.
    psi0_list = np.zeros(max_iter)
    coupling = (gamma**2)*zeta*(1-zeta)
    for t in range(max_iter):
        forcing = 0.5*R*h1[t] + 0.5*R_tilde*h0[t]
        if t != 0:
            # H2[t-1::-1] is H2[t-1], ..., H2[0], paired with psi[0], ..., psi[t-1]
            memory = coupling*np.dot(H2[t-1::-1], psi0_list[:t])
        else:
            memory = 0.0
        # the sum is real up to rounding, so only the real part is stored
        psi0_list[t] = np.real(forcing + memory)

    return psi0_list
